{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:26:50.172208Z",
     "start_time": "2021-10-02T12:26:46.994378Z"
    }
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.simplefilter('ignore')\n",
    "\n",
    "import os\n",
    "import re\n",
    "import glob\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import random_split\n",
    "\n",
    "from transformers import (AutoModel, AutoModelForMaskedLM,\n",
    "                          AutoTokenizer, LineByLineTextDataset,\n",
    "                          DataCollatorForLanguageModeling,\n",
    "                          Trainer, TrainingArguments)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:26:50.283368Z",
     "start_time": "2021-10-02T12:26:50.174789Z"
    }
   },
   "outputs": [],
   "source": [
    "# Gather the training corpus: one entry per kept line of every\n",
    "# juben/<subdir>/*.txt file.\n",
    "# sorted() pins the file order (glob.glob returns paths in arbitrary order),\n",
    "# so the corpus is assembled identically on every run.\n",
    "text_files = sorted(glob.glob('juben/*/*.txt'))\n",
    "text_list = list()\n",
    "for f in text_files:\n",
    "    # Decode explicitly as UTF-8 instead of relying on the platform default\n",
    "    # codec, which differs between machines and can garble or reject the text.\n",
    "    with open(f, 'r', encoding='utf-8') as handler:\n",
    "        # Stream the file line by line rather than loading it whole.\n",
    "        for text in handler:\n",
    "            text = text.strip()\n",
    "            # Keep non-empty lines that do not begin with an ASCII digit\n",
    "            # (re.match only matches at the start of the string).\n",
    "            if len(text) > 0 and (not re.match(r'[0-9]', text)):\n",
    "                text_list.append(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:26:50.301898Z",
     "start_time": "2021-10-02T12:26:50.285853Z"
    }
   },
   "outputs": [],
   "source": [
    "# Write the corpus as one example per line.\n",
    "# UTF-8 is set explicitly because LineByLineTextDataset opens its input with\n",
    "# encoding='utf-8'; the platform default codec may be something else.\n",
    "text_str = '\\n'.join(text_list)\n",
    "with open('juben/text.txt', 'w', encoding='utf-8') as handler:\n",
    "    handler.write(text_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:27:09.231969Z",
     "start_time": "2021-10-02T12:27:06.389057Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('./roberta-wwm-ext-pretrain/tokenizer_config.json',\n",
       " './roberta-wwm-ext-pretrain/special_tokens_map.json',\n",
       " './roberta-wwm-ext-pretrain/vocab.txt',\n",
       " './roberta-wwm-ext-pretrain/added_tokens.json',\n",
       " './roberta-wwm-ext-pretrain/tokenizer.json')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Pretrained checkpoint used as the starting point for masked-LM training.\n",
    "model_name = 'hfl/chinese-roberta-wwm-ext'\n",
    "\n",
    "# Model and tokenizer are both loaded from the same checkpoint name.\n",
    "model, tokenizer = (\n",
    "    AutoModelForMaskedLM.from_pretrained(model_name),\n",
    "    AutoTokenizer.from_pretrained(model_name),\n",
    ")\n",
    "\n",
    "# Store the tokenizer files in the directory that also serves as the\n",
    "# trainer's output_dir; the returned tuple of paths is the cell output.\n",
    "tokenizer.save_pretrained(save_directory='./roberta-wwm-ext-pretrain')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:27:55.566240Z",
     "start_time": "2021-10-02T12:27:51.000895Z"
    }
   },
   "outputs": [],
   "source": [
    "# Tokenise the corpus once, then hold out a fixed share for validation.\n",
    "# Evaluating on the very lines the model is trained on would make eval_loss\n",
    "# (and any best-checkpoint selection based on it) uninformative.\n",
    "VALID_FRACTION = 0.1  # share of lines reserved for validation\n",
    "SPLIT_SEED = 42       # fixes which lines end up in the validation split\n",
    "\n",
    "full_dataset = LineByLineTextDataset(\n",
    "    tokenizer=tokenizer,\n",
    "    file_path=\"juben/text.txt\",\n",
    "    block_size=256)\n",
    "assert len(full_dataset) > 1, 'need at least two lines to split train/valid'\n",
    "\n",
    "n_valid = max(1, int(len(full_dataset) * VALID_FRACTION))\n",
    "n_train = len(full_dataset) - n_valid\n",
    "\n",
    "# A seeded generator keeps the split identical across kernel restarts.\n",
    "train_dataset, valid_dataset = random_split(\n",
    "    full_dataset, [n_train, n_valid],\n",
    "    generator=torch.Generator().manual_seed(SPLIT_SEED))\n",
    "\n",
    "# Masked-language-model collator with a masking probability of 0.15.\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer, mlm=True, mlm_probability=0.15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:28:19.197827Z",
     "start_time": "2021-10-02T12:28:19.126268Z"
    }
   },
   "outputs": [],
   "source": [
    "# Evaluate and checkpoint at the end of every epoch.\n",
    "# Step-based evaluation with a large interval can leave the last part of\n",
    "# training unevaluated; load_best_model_at_end then restores an earlier\n",
    "# checkpoint and the remaining training is thrown away. Epoch-based\n",
    "# evaluation always scores the final state as well.\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"./roberta-wwm-ext-pretrain\",  # checkpoints are written here\n",
    "    overwrite_output_dir=True,\n",
    "    num_train_epochs=3,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    evaluation_strategy='epoch',\n",
    "    save_strategy='epoch',  # kept in step with evaluation_strategy\n",
    "    save_total_limit=2,\n",
    "    metric_for_best_model='eval_loss',\n",
    "    greater_is_better=False,  # a lower eval_loss is better\n",
    "    load_best_model_at_end=True,\n",
    "    prediction_loss_only=True,\n",
    "    report_to=\"none\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:54:54.776496Z",
     "start_time": "2021-10-02T12:28:34.870416Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num examples = 52793\n",
      "  Num Epochs = 3\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 9900\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='9900' max='9900' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [9900/9900 26:14, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>5000</td>\n",
       "      <td>1.553500</td>\n",
       "      <td>1.420220</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Evaluation *****\n",
      "  Num examples = 52793\n",
      "  Batch size = 16\n",
      "Saving model checkpoint to ./roberta-wwm-ext-pretrain/checkpoint-5000\n",
      "Configuration saved in ./roberta-wwm-ext-pretrain/checkpoint-5000/config.json\n",
      "Model weights saved in ./roberta-wwm-ext-pretrain/checkpoint-5000/pytorch_model.bin\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n",
      "Loading best model from ./roberta-wwm-ext-pretrain/checkpoint-5000 (score: 1.4202196598052979).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=9900, training_loss=1.595300934146149, metrics={'train_runtime': 1576.0117, 'train_samples_per_second': 100.494, 'train_steps_per_second': 6.282, 'total_flos': 6458052129123264.0, 'train_loss': 1.595300934146149, 'epoch': 3.0})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Wire the model, the datasets, the collator and the arguments together.\n",
    "trainer = Trainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "# Run training; the returned TrainOutput is shown as the cell result.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-10-02T12:55:21.044635Z",
     "start_time": "2021-10-02T12:55:20.216656Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./roberta-wwm-ext-pretrain\n",
      "Configuration saved in ./roberta-wwm-ext-pretrain/config.json\n",
      "Model weights saved in ./roberta-wwm-ext-pretrain/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "# Save the trainer's model into the directory that already holds the\n",
    "# tokenizer files.\n",
    "trainer.save_model(output_dir='./roberta-wwm-ext-pretrain')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
